{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rOvvWAVTkMR7"
      },
      "source": [
        "# Introduction\n",
        "\n",
        "Welcome to the **Few Shot Object Detection for TensorFlow Lite** Colab. Here, we demonstrate fine tuning of a SSD architecture (pre-trained on COCO) on very few examples of a *novel* class. We will then generate a (downloadable) TensorFlow Lite model for on-device inference.\n",
        "\n",
        "**NOTE:** This Colab is meant for the few-shot detection use-case. To train a model on a large dataset, please follow the [TF2 training](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_training_and_evaluation.md#training) documentation and then [convert](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md) the model to TensorFlow Lite."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3U2sv0upw04O"
      },
      "source": [
        "# Set Up"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vPs64QA1Zdov"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H0rKBV4uZacD"
      },
      "outputs": [],
      "source": [
        "# Support for TF2 models was added after TF 2.3.\n",
        "!pip install tf-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oi28cqGGFWnY"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pathlib\n",
        "\n",
        "# Clone the tensorflow models repository if it doesn't already exist\n",
        "if \"models\" in pathlib.Path.cwd().parts:\n",
        "  while \"models\" in pathlib.Path.cwd().parts:\n",
        "    os.chdir('..')\n",
        "elif not pathlib.Path('models').exists():\n",
        "  !git clone --depth 1 https://github.com/tensorflow/models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NwdsBdGhFanc"
      },
      "outputs": [],
      "source": [
        "# Install the Object Detection API\n",
        "%%bash\n",
        "cd models/research/\n",
        "protoc object_detection/protos/*.proto --python_out=.\n",
        "cp object_detection/packages/tf2/setup.py .\n",
        "python -m pip install ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uZcqD4NLdnf4"
      },
      "outputs": [],
      "source": [
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import os\n",
        "import random\n",
        "import io\n",
        "import imageio\n",
        "import glob\n",
        "import scipy.misc\n",
        "import numpy as np\n",
        "from six import BytesIO\n",
        "from PIL import Image, ImageDraw, ImageFont\n",
        "from IPython.display import display, Javascript\n",
        "from IPython.display import Image as IPyImage\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from object_detection.utils import label_map_util\n",
        "from object_detection.utils import config_util\n",
        "from object_detection.utils import visualization_utils as viz_utils\n",
        "from object_detection.utils import colab_utils\n",
        "from object_detection.utils import config_util\n",
        "from object_detection.builders import model_builder\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IogyryF2lFBL"
      },
      "source": [
        "##Utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-y9R0Xllefec"
      },
      "outputs": [],
      "source": [
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Load an image from file into a numpy array.\n",
        "\n",
        "  Puts image into numpy array to feed into tensorflow graph.\n",
        "  Note that by convention we put it into a numpy array with shape\n",
        "  (height, width, channels), where channels=3 for RGB.\n",
        "\n",
        "  Args:\n",
        "    path: a file path.\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (img_height, img_width, 3)\n",
        "  \"\"\"\n",
        "  img_data = tf.io.gfile.GFile(path, 'rb').read()\n",
        "  image = Image.open(BytesIO(img_data))\n",
        "  (im_width, im_height) = image.size\n",
        "  return np.array(image.getdata()).reshape(\n",
        "      (im_height, im_width, 3)).astype(np.uint8)\n",
        "\n",
        "def plot_detections(image_np,\n",
        "                    boxes,\n",
        "                    classes,\n",
        "                    scores,\n",
        "                    category_index,\n",
        "                    figsize=(12, 16),\n",
        "                    image_name=None):\n",
        "  \"\"\"Wrapper function to visualize detections.\n",
        "\n",
        "  Args:\n",
        "    image_np: uint8 numpy array with shape (img_height, img_width, 3)\n",
        "    boxes: a numpy array of shape [N, 4]\n",
        "    classes: a numpy array of shape [N]. Note that class indices are 1-based,\n",
        "      and match the keys in the label map.\n",
        "    scores: a numpy array of shape [N] or None.  If scores=None, then\n",
        "      this function assumes that the boxes to be plotted are groundtruth\n",
        "      boxes and plot all boxes as black with no classes or scores.\n",
        "    category_index: a dict containing category dictionaries (each holding\n",
        "      category index `id` and category name `name`) keyed by category indices.\n",
        "    figsize: size for the figure.\n",
        "    image_name: a name for the image file.\n",
        "  \"\"\"\n",
        "  image_np_with_annotations = image_np.copy()\n",
        "  viz_utils.visualize_boxes_and_labels_on_image_array(\n",
        "      image_np_with_annotations,\n",
        "      boxes,\n",
        "      classes,\n",
        "      scores,\n",
        "      category_index,\n",
        "      use_normalized_coordinates=True,\n",
        "      min_score_thresh=0.8)\n",
        "  if image_name:\n",
        "    plt.imsave(image_name, image_np_with_annotations)\n",
        "  else:\n",
        "    plt.imshow(image_np_with_annotations)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sSaXL28TZfk1"
      },
      "source": [
        "## Rubber Ducky data\n",
        "\n",
        "We will start with some toy data consisting of 5 images of a rubber\n",
        "ducky.  Note that the [COCO](https://cocodataset.org/#explore) dataset contains a number of animals, but notably, it does *not* contain rubber duckies (or even ducks for that matter), so this is a novel class."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SQy3ND7EpFQM"
      },
      "outputs": [],
      "source": [
        "# Load images and visualize\n",
        "train_image_dir = 'models/research/object_detection/test_images/ducky/train/'\n",
        "train_images_np = []\n",
        "for i in range(1, 6):\n",
        "  image_path = os.path.join(train_image_dir, 'robertducky' + str(i) + '.jpg')\n",
        "  train_images_np.append(load_image_into_numpy_array(image_path))\n",
        "\n",
        "plt.rcParams['axes.grid'] = False\n",
        "plt.rcParams['xtick.labelsize'] = False\n",
        "plt.rcParams['ytick.labelsize'] = False\n",
        "plt.rcParams['xtick.top'] = False\n",
        "plt.rcParams['xtick.bottom'] = False\n",
        "plt.rcParams['ytick.left'] = False\n",
        "plt.rcParams['ytick.right'] = False\n",
        "plt.rcParams['figure.figsize'] = [14, 7]\n",
        "\n",
        "for idx, train_image_np in enumerate(train_images_np):\n",
        "  plt.subplot(2, 3, idx+1)\n",
        "  plt.imshow(train_image_np)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LbOe9Ym7xMGV"
      },
      "source": [
        "# Transfer Learning\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dqb_yjAo3cO_"
      },
      "source": [
        "## Data Preparation\n",
        "\n",
        "First, we populate the groundtruth with pre-annotated bounding boxes.\n",
        "\n",
        "We then add the class annotations (for simplicity, we assume a single 'Duck' class in this colab; though it should be straightforward to extend this to handle multiple classes).  We also convert everything to the format that the training\n",
        "loop below expects (e.g., everything converted to tensors, classes converted to one-hot representations, etc.)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wIAT6ZUmdHOC"
      },
      "outputs": [],
      "source": [
        "gt_boxes = [\n",
        "            np.array([[0.436, 0.591, 0.629, 0.712]], dtype=np.float32),\n",
        "            np.array([[0.539, 0.583, 0.73, 0.71]], dtype=np.float32),\n",
        "            np.array([[0.464, 0.414, 0.626, 0.548]], dtype=np.float32),\n",
        "            np.array([[0.313, 0.308, 0.648, 0.526]], dtype=np.float32),\n",
        "            np.array([[0.256, 0.444, 0.484, 0.629]], dtype=np.float32)\n",
        "]\n",
        "\n",
        "# By convention, our non-background classes start counting at 1.  Given\n",
        "# that we will be predicting just one class, we will therefore assign it a\n",
        "# `class id` of 1.\n",
        "duck_class_id = 1\n",
        "num_classes = 1\n",
        "\n",
        "category_index = {duck_class_id: {'id': duck_class_id, 'name': 'rubber_ducky'}}\n",
        "\n",
        "# Convert class labels to one-hot; convert everything to tensors.\n",
        "# The `label_id_offset` here shifts all classes by a certain number of indices;\n",
        "# we do this here so that the model receives one-hot labels where non-background\n",
        "# classes start counting at the zeroth index.  This is ordinarily just handled\n",
        "# automatically in our training binaries, but we need to reproduce it here.\n",
        "label_id_offset = 1\n",
        "train_image_tensors = []\n",
        "gt_classes_one_hot_tensors = []\n",
        "gt_box_tensors = []\n",
        "for (train_image_np, gt_box_np) in zip(\n",
        "    train_images_np, gt_boxes):\n",
        "  train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(\n",
        "      train_image_np, dtype=tf.float32), axis=0))\n",
        "  gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))\n",
        "  zero_indexed_groundtruth_classes = tf.convert_to_tensor(\n",
        "      np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)\n",
        "  gt_classes_one_hot_tensors.append(tf.one_hot(\n",
        "      zero_indexed_groundtruth_classes, num_classes))\n",
        "print('Done prepping data.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b3_Z3mJWN9KJ"
      },
      "source": [
        "Let's just visualize the rubber duckies as a sanity check\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YBD6l-E4N71y"
      },
      "outputs": [],
      "source": [
        "dummy_scores = np.array([1.0], dtype=np.float32)  # give boxes a score of 100%\n",
        "\n",
        "plt.figure(figsize=(30, 15))\n",
        "for idx in range(5):\n",
        "  plt.subplot(2, 3, idx+1)\n",
        "  plot_detections(\n",
        "      train_images_np[idx],\n",
        "      gt_boxes[idx],\n",
        "      np.ones(shape=[gt_boxes[idx].shape[0]], dtype=np.int32),\n",
        "      dummy_scores, category_index)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ghDAsqfoZvPh"
      },
      "source": [
        "## Load mobile-friendly model\n",
        "\n",
        "In this cell we build a mobile-friendly single-stage detection architecture (SSD MobileNet V2 FPN-Lite) and restore all but the classification layer at the top (which will be randomly initialized).\n",
        "\n",
        "**NOTE**: TensorFlow Lite only supports SSD models for now.\n",
        "\n",
        "For simplicity, we have hardcoded a number of things in this colab for the specific SSD architecture at hand (including assuming that the image size will always be 320x320), however it is not difficult to generalize to other model configurations (`pipeline.config` in the zip downloaded from the [Model Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.)).\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9J16r3NChD-7"
      },
      "outputs": [],
      "source": [
        "# Download the checkpoint and put it into models/research/object_detection/test_data/\n",
        "\n",
        "!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz\n",
        "!tar -xf ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.tar.gz\n",
        "!if [ -d \"models/research/object_detection/test_data/checkpoint\" ]; then rm -Rf models/research/object_detection/test_data/checkpoint; fi\n",
        "!mkdir models/research/object_detection/test_data/checkpoint\n",
        "!mv ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8/checkpoint models/research/object_detection/test_data/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RyT4BUbaMeG-"
      },
      "outputs": [],
      "source": [
        "tf.keras.backend.clear_session()\n",
        "\n",
        "print('Building model and restoring weights for fine-tuning...', flush=True)\n",
        "num_classes = 1\n",
        "pipeline_config = 'models/research/object_detection/configs/tf2/ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.config'\n",
        "checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'\n",
        "\n",
        "# This will be where we save checkpoint \u0026 config for TFLite conversion later.\n",
        "output_directory = 'output/'\n",
        "output_checkpoint_dir = os.path.join(output_directory, 'checkpoint')\n",
        "\n",
        "# Load pipeline config and build a detection model.\n",
        "#\n",
        "# Since we are working off of a COCO architecture which predicts 90\n",
        "# class slots by default, we override the `num_classes` field here to be just\n",
        "# one (for our new rubber ducky class).\n",
        "configs = config_util.get_configs_from_pipeline_file(pipeline_config)\n",
        "model_config = configs['model']\n",
        "model_config.ssd.num_classes = num_classes\n",
        "model_config.ssd.freeze_batchnorm = True\n",
        "detection_model = model_builder.build(\n",
        "      model_config=model_config, is_training=True)\n",
        "# Save new pipeline config\n",
        "pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)\n",
        "config_util.save_pipeline_config(pipeline_proto, output_directory)\n",
        "\n",
        "# Set up object-based checkpoint restore --- SSD has two prediction\n",
        "# `heads` --- one for classification, the other for box regression.  We will\n",
        "# restore the box regression head but initialize the classification head\n",
        "# from scratch (we show the omission below by commenting out the line that\n",
        "# we would add if we wanted to restore both heads)\n",
        "fake_box_predictor = tf.compat.v2.train.Checkpoint(\n",
        "    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,\n",
        "    # _prediction_heads=detection_model._box_predictor._prediction_heads,\n",
        "    #    (i.e., the classification head that we *will not* restore)\n",
        "    _box_prediction_head=detection_model._box_predictor._box_prediction_head,\n",
        "    )\n",
        "fake_model = tf.compat.v2.train.Checkpoint(\n",
        "          _feature_extractor=detection_model._feature_extractor,\n",
        "          _box_predictor=fake_box_predictor)\n",
        "ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)\n",
        "ckpt.restore(checkpoint_path).expect_partial()\n",
        "\n",
        "# To save checkpoint for TFLite conversion.\n",
        "exported_ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)\n",
        "ckpt_manager = tf.train.CheckpointManager(\n",
        "    exported_ckpt, output_checkpoint_dir, max_to_keep=1)\n",
        "\n",
        "# Run model through a dummy image so that variables are created\n",
        "image, shapes = detection_model.preprocess(tf.zeros([1, 320, 320, 3]))\n",
        "prediction_dict = detection_model.predict(image, shapes)\n",
        "_ = detection_model.postprocess(prediction_dict, shapes)\n",
        "print('Weights restored!')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pCkWmdoZZ0zJ"
      },
      "source": [
        "## Eager training loop (Fine-tuning)\n",
        "\n",
        "Some of the parameters in this block have been set empirically: for example, `learning_rate`, `num_batches` \u0026 `momentum` for SGD. These are just a starting point, you will have to tune these for your data \u0026 model architecture to get the best results.\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nyHoF4mUrv5-"
      },
      "outputs": [],
      "source": [
        "tf.keras.backend.set_learning_phase(True)\n",
        "\n",
        "# These parameters can be tuned; since our training set has 5 images\n",
        "# it doesn't make sense to have a much larger batch size, though we could\n",
        "# fit more examples in memory if we wanted to.\n",
        "batch_size = 5\n",
        "learning_rate = 0.15\n",
        "num_batches = 1000\n",
        "\n",
        "# Select variables in top layers to fine-tune.\n",
        "trainable_variables = detection_model.trainable_variables\n",
        "to_fine_tune = []\n",
        "prefixes_to_train = [\n",
        "  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',\n",
        "  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']\n",
        "for var in trainable_variables:\n",
        "  if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):\n",
        "    to_fine_tune.append(var)\n",
        "\n",
        "# Set up forward + backward pass for a single train step.\n",
        "def get_model_train_step_function(model, optimizer, vars_to_fine_tune):\n",
        "  \"\"\"Get a tf.function for training step.\"\"\"\n",
        "\n",
        "  # Use tf.function for a bit of speed.\n",
        "  # Comment out the tf.function decorator if you want the inside of the\n",
        "  # function to run eagerly.\n",
        "  @tf.function\n",
        "  def train_step_fn(image_tensors,\n",
        "                    groundtruth_boxes_list,\n",
        "                    groundtruth_classes_list):\n",
        "    \"\"\"A single training iteration.\n",
        "\n",
        "    Args:\n",
        "      image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.\n",
        "        Note that the height and width can vary across images, as they are\n",
        "        reshaped within this function to be 320x320.\n",
        "      groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type\n",
        "        tf.float32 representing groundtruth boxes for each image in the batch.\n",
        "      groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]\n",
        "        with type tf.float32 representing groundtruth boxes for each image in\n",
        "        the batch.\n",
        "\n",
        "    Returns:\n",
        "      A scalar tensor representing the total loss for the input batch.\n",
        "    \"\"\"\n",
        "    shapes = tf.constant(batch_size * [[320, 320, 3]], dtype=tf.int32)\n",
        "    model.provide_groundtruth(\n",
        "        groundtruth_boxes_list=groundtruth_boxes_list,\n",
        "        groundtruth_classes_list=groundtruth_classes_list)\n",
        "    with tf.GradientTape() as tape:\n",
        "      preprocessed_images = tf.concat(\n",
        "          [detection_model.preprocess(image_tensor)[0]\n",
        "           for image_tensor in image_tensors], axis=0)\n",
        "      prediction_dict = model.predict(preprocessed_images, shapes)\n",
        "      losses_dict = model.loss(prediction_dict, shapes)\n",
        "      total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']\n",
        "      gradients = tape.gradient(total_loss, vars_to_fine_tune)\n",
        "      optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))\n",
        "    return total_loss\n",
        "\n",
        "  return train_step_fn\n",
        "\n",
        "optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)\n",
        "train_step_fn = get_model_train_step_function(\n",
        "    detection_model, optimizer, to_fine_tune)\n",
        "\n",
        "print('Start fine-tuning!', flush=True)\n",
        "for idx in range(num_batches):\n",
        "  # Grab keys for a random subset of examples\n",
        "  all_keys = list(range(len(train_images_np)))\n",
        "  random.shuffle(all_keys)\n",
        "  example_keys = all_keys[:batch_size]\n",
        "\n",
        "  # Note that we do not do data augmentation in this demo.  If you want a\n",
        "  # a fun exercise, we recommend experimenting with random horizontal flipping\n",
        "  # and random cropping :)\n",
        "  gt_boxes_list = [gt_box_tensors[key] for key in example_keys]\n",
        "  gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]\n",
        "  image_tensors = [train_image_tensors[key] for key in example_keys]\n",
        "\n",
        "  # Training step (forward pass + backwards pass)\n",
        "  total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)\n",
        "\n",
        "  if idx % 100 == 0:\n",
        "    print('batch ' + str(idx) + ' of ' + str(num_batches)\n",
        "    + ', loss=' +  str(total_loss.numpy()), flush=True)\n",
        "\n",
        "print('Done fine-tuning!')\n",
        "\n",
        "ckpt_manager.save()\n",
        "print('Checkpoint saved!')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cYk1_9Fc2lZO"
      },
      "source": [
        "# Export \u0026 run with TensorFlow Lite\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y0nsDVEd9SuX"
      },
      "source": [
        "## Model Conversion\n",
        "\n",
        "First, we invoke the `export_tflite_graph_tf2.py` script to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TensorFlow Lite Converter for generating the final model.\n",
        "\n",
        "To know more about this process, please look at [this documentation](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dyrqHSQQ7WKE"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "python models/research/object_detection/export_tflite_graph_tf2.py \\\n",
        "  --pipeline_config_path output/pipeline.config \\\n",
        "  --trained_checkpoint_dir output/checkpoint \\\n",
        "  --output_directory tflite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m5hjPyR78bgs"
      },
      "outputs": [],
      "source": [
        "!tflite_convert --saved_model_dir=tflite/saved_model --output_file=tflite/model.tflite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WHlXL1x_Z3tc"
      },
      "source": [
        "## Test .tflite model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WcE6OwrHQJya"
      },
      "outputs": [],
      "source": [
        "test_image_dir = 'models/research/object_detection/test_images/ducky/test/'\n",
        "test_images_np = []\n",
        "for i in range(1, 50):\n",
        "  image_path = os.path.join(test_image_dir, 'out' + str(i) + '.jpg')\n",
        "  test_images_np.append(np.expand_dims(\n",
        "      load_image_into_numpy_array(image_path), axis=0))\n",
        "\n",
        "# Again, uncomment this decorator if you want to run inference eagerly\n",
        "def detect(interpreter, input_tensor):\n",
        "  \"\"\"Run detection on an input image.\n",
        "\n",
        "  Args:\n",
        "    interpreter: tf.lite.Interpreter\n",
        "    input_tensor: A [1, height, width, 3] Tensor of type tf.float32.\n",
        "      Note that height and width can be anything since the image will be\n",
        "      immediately resized according to the needs of the model within this\n",
        "      function.\n",
        "\n",
        "  Returns:\n",
        "    A dict containing 3 Tensors (`detection_boxes`, `detection_classes`,\n",
        "      and `detection_scores`).\n",
        "  \"\"\"\n",
        "  input_details = interpreter.get_input_details()\n",
        "  output_details = interpreter.get_output_details()\n",
        "\n",
        "  # We use the original model for pre-processing, since the TFLite model doesn't\n",
        "  # include pre-processing.\n",
        "  preprocessed_image, shapes = detection_model.preprocess(input_tensor)\n",
        "  interpreter.set_tensor(input_details[0]['index'], preprocessed_image.numpy())\n",
        "\n",
        "  interpreter.invoke()\n",
        "\n",
        "  boxes = interpreter.get_tensor(output_details[0]['index'])\n",
        "  classes = interpreter.get_tensor(output_details[1]['index'])\n",
        "  scores = interpreter.get_tensor(output_details[2]['index'])\n",
        "  return boxes, classes, scores\n",
        "\n",
        "# Load the TFLite model and allocate tensors.\n",
        "interpreter = tf.lite.Interpreter(model_path=\"tflite/model.tflite\")\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "# Note that the first frame will trigger tracing of the tf.function, which will\n",
        "# take some time, after which inference should be fast.\n",
        "\n",
        "label_id_offset = 1\n",
        "for i in range(len(test_images_np)):\n",
        "  input_tensor = tf.convert_to_tensor(test_images_np[i], dtype=tf.float32)\n",
        "  boxes, classes, scores = detect(interpreter, input_tensor)\n",
        "\n",
        "  plot_detections(\n",
        "      test_images_np[i][0],\n",
        "      boxes[0],\n",
        "      classes[0].astype(np.uint32) + label_id_offset,\n",
        "      scores[0],\n",
        "      category_index, figsize=(15, 20), image_name=\"gif_frame_\" + ('%02d' % i) + \".jpg\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZkMPOSQE0x8C"
      },
      "outputs": [],
      "source": [
        "imageio.plugins.freeimage.download()\n",
        "\n",
        "anim_file = 'duckies_test.gif'\n",
        "\n",
        "filenames = glob.glob('gif_frame_*.jpg')\n",
        "filenames = sorted(filenames)\n",
        "last = -1\n",
        "images = []\n",
        "for filename in filenames:\n",
        "  image = imageio.imread(filename)\n",
        "  images.append(image)\n",
        "\n",
        "imageio.mimsave(anim_file, images, 'GIF-FI', fps=5)\n",
        "\n",
        "display(IPyImage(open(anim_file, 'rb').read()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yzaHWsS58_PQ"
      },
      "source": [
        "## (Optional) Download model\n",
        "\n",
        "This model can be run on-device with **TensorFlow Lite**. Look at [our SSD model signature](https://www.tensorflow.org/lite/models/object_detection/overview#uses_and_limitations) to understand how to interpret the model IO tensors. Our [Object Detection example](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection) is a good starting point for integrating the model into your mobile app.\n",
        "\n",
        "Refer to TFLite's [inference documentation](https://www.tensorflow.org/lite/guide/inference) for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gZ6vac3RAY3j"
      },
      "outputs": [],
      "source": [
        "from google.colab import files\n",
        "files.download('tflite/model.tflite') "
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "eager_few_shot_od_training_tflite.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
